#include "simjoin.h"

// Builds a join of the objects in s1 with the objects in s2, keeping the
// k best-scoring pairs.
//
// The by-value parameters are already private copies, so they are swapped
// into the members instead of being copied a second time.
Simjoin::Simjoin(vector <Object> s1, vector <Object> s2, int k)
{
    set1.swap(s1);
    set2.swap(s2);
    this->k = k;
}

// Exhaustive top-k similarity join over set1 x set2.
//
// Seeds result_set with the first k pairs of SetOrder() and uses the k-th
// best final_score as the threshold. It then scans the remaining pairs in
// that order. Each pair goes through:
//   * the score stop condition (break as soon as either score is below the
//     threshold),
//   * size filtering,
//   * prefix filtering (tokens shared by the two prefixes).
// Verify() is called for pairs that share at least one prefix token. It may
// replace the k-th result and returns the new threshold.
//
// Side effects:
//   * fills result_set;
//   * updates c_all and c_s (and c_p / c_v through Verify);
//   * turns l_pr into the average prefix length;
//   * records the run with saveStatistic().
void Simjoin::join()
{
    initParam();
    c_all = 0;
    result_set.clear();  // results of an earlier join must not leak into this run

    vector <pair<int, int> > pairOrder = SetOrder();
    const int maxsize = pairOrder.size();

    // Seed the result with the first k pairs in scan order.
    for (int i = 0; i < min(k, maxsize); i++) {
        ObjectPair new_pair(set1[pairOrder[i].first], set2[pairOrder[i].second]);
        result_set.push_back(new_pair);
    }
    sort(result_set.begin(), result_set.end());

    // With k <= 0, or fewer than k pairs overall, there is no k-th score to
    // use as a threshold (result_set[k - 1] would be out of range). In that
    // case the seed already holds every pair, so nothing is left to scan.
    if (k <= 0 || maxsize < k)
        return;

    double threshold = result_set[k - 1].final_score;

    for (int pairNumber = k; pairNumber < maxsize; pairNumber++)
    {
        Object& x = set1[pairOrder[pairNumber].first];
        Object& y = set2[pairOrder[pairNumber].second];

        if (x.score < threshold || y.score < threshold)  break;  //stop condition
        c_all++;

        if (max(x.data.size(),y.data.size())*(threshold/min(x.score, y.score)) <= min(x.data.size(),y.data.size())) //size filtering
        {
            c_s++;
            int alpha = CountOverlapThreshold(threshold/min(x.score, y.score), x.data.size(), y.data.size());
            int x_prefix = x.data.size() - alpha + 1;
            int y_prefix = y.data.size() - alpha + 1;
            int sharedTokensCount = 0;
            l_pr += x_prefix;
            l_pr += y_prefix;

            // prefix filtering: count tokens the two prefixes have in common
            for (int i = 0; i < x_prefix; i++)
                for (int j = 0; j < y_prefix; j++)
                    if (x.data[i] == y.data[j])
                        sharedTokensCount++;

            if (sharedTokensCount != 0)
                threshold = Verify(x, y, alpha, sharedTokensCount);
        }
    }

    // Average prefix length per object that passed size filtering. Skipped
    // when none did, to avoid dividing by zero (l_pr is still 0 then).
    if (c_s > 0)
        l_pr = l_pr/(2*c_s);
    saveStatistic();
}

// Time-bounded variant of join().
//
// countAllPairs(time) estimates how many candidate pairs fit into `time`.
// At most that many pairs after the k seed pairs are scanned. Each scanned
// pair uses the same stop condition, size filtering, prefix filtering and
// Verify() step as join().
//
// Side effects:
//   * fills result_set;
//   * updates c_s (and c_p / c_v through Verify);
//   * turns l_pr into the average prefix length;
//   * records the run with saveStatistic().
//
// NOTE(review): c_all keeps the estimated budget, not the number of pairs
// actually examined (the loop may stop early). saveStatistic() therefore
// records budget-based ratios here, whereas join() records real counts.
// Confirm this is intended.
void Simjoin::join(double time)
{
    initParam();
    c_all = countAllPairs(time);
    result_set.clear();  // results of an earlier join must not leak into this run

    vector <pair<int, int> > pairOrder = SetOrder();
    const int maxsize = pairOrder.size();

    // Seed the result with the first k pairs in scan order.
    for (int i = 0; i < min(k, maxsize); i++) {
        ObjectPair new_pair(set1[pairOrder[i].first], set2[pairOrder[i].second]);
        result_set.push_back(new_pair);
    }
    sort(result_set.begin(), result_set.end());

    // With k <= 0, or fewer than k pairs overall, there is no k-th score to
    // use as a threshold (result_set[k - 1] would be out of range). In that
    // case the seed already holds every pair, so nothing is left to scan.
    if (k <= 0 || maxsize < k)
        return;

    double threshold = result_set[k - 1].final_score;

    // Exclusive end of the scan: the time budget or the last pair, whichever comes first.
    const int scanEnd = min(k + c_all, maxsize);
    for (int pairNumber = k; pairNumber < scanEnd; pairNumber++)
    {
        Object& x = set1[pairOrder[pairNumber].first];
        Object& y = set2[pairOrder[pairNumber].second];

        if (x.score < threshold || y.score < threshold)  break;  //stop condition

        if (max(x.data.size(),y.data.size())*(threshold/min(x.score, y.score)) <= min(x.data.size(),y.data.size())) //size filtering
        {
            c_s++;
            int alpha = CountOverlapThreshold(threshold/min(x.score, y.score), x.data.size(), y.data.size());
            int x_prefix = x.data.size() - alpha + 1;
            int y_prefix = y.data.size() - alpha + 1;
            int sharedTokensCount = 0;
            l_pr += x_prefix;
            l_pr += y_prefix;

            // prefix filtering: count tokens the two prefixes have in common
            for (int i = 0; i < x_prefix; i++)
                for (int j = 0; j < y_prefix; j++)
                    if (x.data[i] == y.data[j])
                        sharedTokensCount++;

            if (sharedTokensCount != 0)
                threshold = Verify(x, y, alpha, sharedTokensCount);
        }
    }

    // Average prefix length per object that passed size filtering. Skipped
    // when none did, to avoid dividing by zero (l_pr is still 0 then).
    if (c_s > 0)
        l_pr = l_pr/(2*c_s);
    saveStatistic();
}

// Verifies a candidate pair that survived prefix filtering. If the pair
// beats the current k-th result, it is inserted into result_set.
//
// Parameters:
//   alpha             - minimum overlap from CountOverlapThreshold().
//   sharedTokensCount - number of tokens the two prefixes have in common.
//
// The rest of the overlap is counted on suffixes:
//   * the object whose last prefix token comes first in the token order
//     contributes only the part after its prefix;
//   * the other object contributes everything from index sharedTokensCount
//     on.
// This assumes tokens are sorted and unique within each object.
//
// Returns the k-th best final_score after the update, i.e. the threshold the
// caller should continue with. Every path returns it.
// Expects result_set to hold k pairs ordered by descending final_score, as
// join() sets it up. NOTE(review): relies on ObjectPair::operator< sorting
// in descending final_score order — confirm.
double Simjoin::Verify(Object& x, Object& y, int alpha, int sharedTokensCount)
{
    int x_prefix = x.data.size() - alpha + 1;
    int y_prefix = y.data.size() - alpha + 1;
    c_p++;

    int overlap = sharedTokensCount;
    if (x.data[x_prefix - 1] < y.data[y_prefix - 1])  {  // global order
        vector<Token> x_suffix(x.data.begin() + x_prefix, x.data.end());
        vector<Token> y_suffix(y.data.begin() + sharedTokensCount, y.data.end());
        if (!x_suffix.empty() && !y_suffix.empty())
            overlap += countOverlap(x_suffix, y_suffix);
    }
    else {
        vector<Token> x_suffix(x.data.begin() + sharedTokensCount, x.data.end());
        vector<Token> y_suffix(y.data.begin() + y_prefix, y.data.end());
        if (!x_suffix.empty() && !y_suffix.empty())
            overlap += countOverlap(x_suffix, y_suffix);
    }

    if (overlap > alpha)
    {
        ObjectPair new_pair(x, y, overlap);
        if (new_pair.final_score > result_set[k - 1].final_score) {
            c_v++;
            result_set.pop_back();  // drop the current k-th result

            // Find the first position whose score is below the new pair's.
            // The loop is bounded by size(), so a pair that ranks last among
            // the remaining k-1 (always the case for k == 1) is appended
            // instead of reading past the end.
            size_t pos = 0;
            while (pos < result_set.size() && result_set[pos].final_score >= new_pair.final_score)
                pos++;
            result_set.insert(result_set.begin() + pos, new_pair);
        }
    }

    // The caller assigns this value to its threshold, so it must be returned
    // even when the pair is rejected (then it is simply unchanged).
    return result_set[k - 1].final_score;
}

// Returns every index pair (i, j), with i into set1 and j into set2, in the
// order the join scans them. Pairs are grouped into layers by L = max(i, j).
// Within layer L the order is:
//   (L, 0), (L, 1), ..., (L, L-1),  then  (L, L), (L-1, L), ..., (0, L),
// skipping any pair whose index falls outside its set.
// Each valid pair appears exactly once. Both sets may be empty, and either
// set may be the larger one.
vector <pair<int, int> > Simjoin::SetOrder()
{
    const int n1 = set1.size();
    const int n2 = set2.size();
    const int layerCount = max(n1, n2);

    vector <pair<int, int> > pairOrder;
    pairOrder.reserve(set1.size() * set2.size());

    for (int layer = 0; layer < layerCount; layer++) {
        // row part: (layer, j) for j below the diagonal (needs a set1 row)
        if (layer < n1)
            for (int j = 0; j < min(layer, n2); j++)
                pairOrder.push_back(pair<int, int>(layer, j));

        // column part: diagonal (if it exists), then upwards (needs a set2 column)
        if (layer < n2)
            for (int i = min(layer, n1 - 1); i >= 0; i--)
                pairOrder.push_back(pair<int, int>(i, layer));
    }

    return pairOrder;
}

// Minimum number of shared tokens two objects of sizes size_x and size_y
// need to reach Jaccard similarity simThreshold:
//   J(x, y) >= t  <=>  |x & y| >= t * (|x| + |y|) / (1 + t).
// The bound is rounded down and is never less than one token.
int Simjoin::CountOverlapThreshold(double simThreshold, int size_x, int size_y)
{
    const double minOverlap = floor(simThreshold*(size_x + size_y)/(1 + simThreshold));
    if (minOverlap > 1.0)
        return static_cast<int>(minOverlap);
    return 1;
}

// Counts the tokens common to object1 and object2 with a single merge-style
// walk. This is only correct if both lists are sorted by Token's operator<.
// Returns 0 for empty inputs: the bounds use `<` so an empty vector no
// longer underflows `size() - 1`.
int Simjoin::countOverlap(vector<Token> object1, vector<Token> object2)
{
    size_t count1 = 0;
    size_t count2 = 0;
    int overlap = 0;

    while (count1 < object1.size() && count2 < object2.size()) {
        if (object1[count1] < object2[count2])
            count1++;
        else if (object2[count2] < object1[count1])
            count2++;
        else {
            // neither precedes the other; count only a genuine match
            if (object1[count1] == object2[count2])
                overlap++;
            count1++;
            count2++;
        }
    }

    return overlap;
}

// Resets the per-run counters and l_pr (the running sum of prefix lengths).
// Also computes l_ob, the average number of tokens per object over set1
// and set2.
// Safe for empty sets: the loops use `<` bounds, and l_ob stays 0 when
// there are no objects instead of dividing by zero.
void Simjoin::initParam() {
    c_s = 0;
    c_p = 0;
    c_v = 0;

    l_pr = 0;
    l_ob = 0;
    for (size_t i = 0; i < set1.size(); i++)
        l_ob += set1[i].data.size();
    for (size_t i = 0; i < set2.size(); i++)
        l_ob += set2[i].data.size();

    const size_t objectCount = set1.size() + set2.size();
    if (objectCount > 0)
        l_ob = l_ob/objectCount;
}

// numerator / denominator, or 0 when the denominator is not positive.
// An empty run has no meaningful ratio. These ratios are meant to be averaged
// (see countAllPairs), where a NaN from 0/0 would contaminate the average.
static double ratioOrZero(double numerator, double denominator)
{
    return denominator > 0 ? numerator / denominator : 0.0;
}

// Appends this run's ratios to the global Statistic singleton:
//   w = pairs examined / pairs after the k seed pairs,
//   x = size-filter survivors / pairs examined,
//   y = Verify calls / size-filter survivors,
//   z = result updates / Verify calls,
//   p = average prefix length / average object length.
// The pair total is computed in double so that n1*n2 < k cannot wrap around
// as unsigned arithmetic.
void Simjoin::saveStatistic()
{
    const double remainingPairs = static_cast<double>(set1.size()) * set2.size() - k;

    Statistic::instance().w_stat.push_back(ratioOrZero(c_all, remainingPairs));
    Statistic::instance().x_stat.push_back(ratioOrZero(c_s, c_all));
    Statistic::instance().y_stat.push_back(ratioOrZero(c_p, c_s));
    Statistic::instance().z_stat.push_back(ratioOrZero(c_v, c_p));
    Statistic::instance().p_stat.push_back(ratioOrZero(l_pr, l_ob));
}

// Arithmetic mean of a numeric container, or `fallback` when it is empty.
template <typename Container>
static double averageOr(const Container& values, double fallback)
{
    if (values.size() == 0)
        return fallback;
    return accumulate(values.begin(), values.end(), 0.0) / values.size();
}

// Estimates how many candidate pairs join(time) can examine within `time`
// and stores the estimate in c_all, which is also the return value.
// The estimate is (time/T minus a seeding term that grows with k and l_ob)
// divided by a per-pair cost built from the averaged statistics w, x, y, z,
// p and a.
// Before any run has been recorded (x_stat empty), defaults are used:
// 0.5 for w, x, y, z, p and -1 for a.
// Also prints the expected quality 1 - exp(a*c_all / (w * remaining pairs)).
// NOTE(review): T looks like a per-operation time constant; its unit is not
// visible here — confirm.
int Simjoin::countAllPairs(double time)
{
    const double T = 0.8/1000000;
    double w = 0.5, x = 0.5, y = 0.5, z = 0.5, p = 0.5;
    double a = -1;

    if (Statistic::instance().x_stat.size() != 0) {
        // accumulate() returns the sum; its init argument is taken by value,
        // so the result must be used rather than discarded.
        w = averageOr(Statistic::instance().w_stat, 0.5);
        x = averageOr(Statistic::instance().x_stat, 0.5);
        y = averageOr(Statistic::instance().y_stat, 0.5);
        z = averageOr(Statistic::instance().z_stat, 0.5);
        p = averageOr(Statistic::instance().p_stat, 0.5);
        a = averageOr(Statistic::instance().a_stat, -1);
    }

    c_all = floor((time/T - k*(5 + 2*l_ob +  log10(k)/log10(2)))/(3 + x*(p*p*l_ob*l_ob + 1) +
                                                                 x*y*((2 - p)*l_ob + 2) + x*y*z*(2 + k)));

    // Pairs beyond the k seed pairs. Computed in double so that n1*n2 < k
    // cannot wrap around as unsigned arithmetic.
    const double remainingPairs = static_cast<double>(set1.size()) * set2.size() - k;
    // With no pair left to scan, the seed already is the exact answer.
    const double qual = remainingPairs > 0 ? 1 - exp(a*c_all/(w*remainingPairs)) : 1.0;
    cout<<"expected quality is "<<qual<<endl;

    return c_all;
}
